{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "colab": {
   "name": "Cifar10_Ax_hyperparam_tuning.ipynb",
   "provenance": []
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "m4TggyOPfapy"
   },
   "source": [
    "# FastResNet hyperparameter tuning with [Ax](https://ax.dev/) on CIFAR10\n",
    "\n",
    "In this notebook we provide an example of hyperparameter tuning with the [Ax](https://ax.dev/) package. We will train a ResNet model from [David Page's repository](https://github.com/davidcpage/cifar10-fast) on CIFAR10.\n",
    "\n",
    "\n",
    "### Why Ax?\n",
    "\n",
    "The [Why Ax?](https://ax.dev/docs/why-ax.html) page answers this question best:\n",
    "\n",
    "> Ax is a platform for optimizing any kind of experiment, including machine learning experiments, A/B tests, and simulations. Ax can optimize discrete configurations (e.g., variants of an A/B test) using multi-armed bandit optimization, and continuous (e.g., integer or floating point)-valued configurations using Bayesian optimization. This makes it suitable for a wide range of applications.\n",
    "\n",
    "Other interesting packages include [ray-tune](https://ray.readthedocs.io/en/latest/tune.html), [optuna](https://github.com/pfnet/optuna) and many others. As a side note, optuna provides an example with Ignite [here](https://github.com/pfnet/optuna/blob/master/examples/ignite_simple.py).\n",
    "\n",
    "\n",
    "### Fast ResNet model\n",
    "\n",
    "We will reimplement a ResNet model from David Page's [cifar10-fast repository](https://github.com/davidcpage/cifar10-fast), which trains very fast (94% test accuracy in 26 seconds on an NVIDIA V100). For the sake of simplicity, we will not apply all the preprocessing used in the repository (please see the [bag-of-tricks notebook](https://github.com/davidcpage/cifar10-fast/blob/master/bag_of_tricks.ipynb) for details).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cRMcUAbLfap7"
   },
   "source": [
    "### Setup dependencies\n",
    "\n",
    "Please install:\n",
    "- `torchvision`\n",
    "- `pytorch-ignite`\n",
    "- `ax-platform` (Ax)\n",
    "- `tensorboardX` (used for TensorBoard logging)\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "zSAI-Dknfap-"
   },
   "source": [
    "!pip install pytorch-ignite tensorboardX ax-platform"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "S_q9Q0zIfaqK"
   },
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"../../\")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "SwOlZhS6faqM"
   },
   "source": [
    "import torch\n",
    "import ignite\n",
    "\n",
    "torch.__version__, ignite.__version__"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "kWGvqsNofaqP"
   },
   "source": [
    "### Setup model\n",
    "\n",
    "The cifar10-fast model is inspired by the ResNet family of models and, in order to run fast, it uses various tricks such as:\n",
    "- `conv + batch norm + activation + pool` -> `conv + pool + batch norm + activation`\n",
    "- `batchnorm` -> `ghost batchnorm` -> `frozen ghost batchnorm`\n",
    "- `ReLU` -> `CeLU`\n",
    "- data whitening as a non-learnable convolution (we will not implement it)\n",
    "\n",
    "The network architecture looks like this:\n",
    "\n",
    "![fastresnet](https://github.com/abdulelahsm/ignite/blob/update-tutorials/examples/notebooks/assets/fastresnet_v2.svg?raw=1)\n",
    "\n",
    "\n",
    "Please see the [bag-of-tricks notebook](https://github.com/davidcpage/cifar10-fast/blob/master/bag_of_tricks.ipynb) for more details.\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Ga6ykOK4faqR"
   },
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class GhostBatchNorm(nn.BatchNorm2d):\n",
    "    \"\"\"\n",
    "    From : https://github.com/davidcpage/cifar10-fast/blob/master/bag_of_tricks.ipynb\n",
    "\n",
    "    Batch norm seems to work best with batch size of around 32. The reasons presumably have to do \n",
    "    with noise in the batch statistics and specifically a balance between a beneficial regularising effect \n",
    "    at intermediate batch sizes and an excess of noise at small batches.\n",
    "    \n",
    "    Our batches are of size 512 and we can't afford to reduce them without taking a serious hit on training times, \n",
    "    but we can apply batch norm separately to subsets of a training batch. This technique, known as 'ghost' batch \n",
    "    norm, is usually used in a distributed setting but is just as useful when using large batches on a single node. \n",
    "    It isn't supported directly in PyTorch but we can roll our own easily enough.\n",
    "    \"\"\"\n",
    "    def __init__(self, num_features, num_splits, eps=1e-05, momentum=0.1, weight=True, bias=True):\n",
    "        super(GhostBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum)\n",
    "        self.weight.data.fill_(1.0)\n",
    "        self.bias.data.fill_(0.0)\n",
    "        self.weight.requires_grad = weight\n",
    "        self.bias.requires_grad = bias        \n",
    "        self.num_splits = num_splits\n",
    "        self.register_buffer('running_mean', torch.zeros(num_features*self.num_splits))\n",
    "        self.register_buffer('running_var', torch.ones(num_features*self.num_splits))\n",
    "\n",
    "    def train(self, mode=True):\n",
    "        if (self.training is True) and (mode is False):\n",
    "            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0).repeat(self.num_splits)\n",
    "            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0).repeat(self.num_splits)\n",
    "        return super(GhostBatchNorm, self).train(mode)\n",
    "        \n",
    "    def forward(self, input):\n",
    "        N, C, H, W = input.shape\n",
    "        if self.training or not self.track_running_stats:\n",
    "            return F.batch_norm(\n",
    "                input.view(-1, C*self.num_splits, H, W), self.running_mean, self.running_var, \n",
    "                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),\n",
    "                True, self.momentum, self.eps).view(N, C, H, W) \n",
    "        else:\n",
    "            return F.batch_norm(\n",
    "                input, self.running_mean[:self.num_features], self.running_var[:self.num_features], \n",
    "                self.weight, self.bias, False, self.momentum, self.eps)\n",
    "\n",
    "        \n",
    "class IdentityResidualBlock(nn.Module):\n",
    "\n",
    "    def __init__(self, num_channels, \n",
    "                 conv_ksize=3, conv_pad=1,\n",
    "                 gbn_num_splits=16):\n",
    "        super(IdentityResidualBlock, self).__init__()\n",
    "        self.res1 = nn.Sequential(\n",
    "            Conv2d(num_channels, num_channels, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),\n",
    "            GhostBatchNorm(num_channels, num_splits=gbn_num_splits, weight=False),\n",
    "            nn.CELU(alpha=0.3)         \n",
    "        )\n",
    "        self.res2 = nn.Sequential(\n",
    "            Conv2d(num_channels, num_channels, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),\n",
    "            GhostBatchNorm(num_channels, num_splits=gbn_num_splits, weight=False),\n",
    "            nn.CELU(alpha=0.3)    \n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        residual = x\n",
    "        x = self.res1(x)\n",
    "        x = self.res2(x)\n",
    "        return x + residual\n",
    "    \n",
    "\n",
    "# We override conv2d to get proper padding for kernel size = 2   \n",
    "class Conv2d(nn.Conv2d):\n",
    "    \n",
    "    def __init__(self, *args, **kwargs):\n",
    "        super(Conv2d, self).__init__(*args, **kwargs)\n",
    "        if self.kernel_size == (2, 2):\n",
    "            self.forward = self.ksize_2_forward\n",
    "            self.ksize_2_padding = (0, self.padding[0], 0, self.padding[1])\n",
    "            self.padding = (0, 0)\n",
    "        \n",
    "    def ksize_2_forward(self, x):\n",
    "        x = F.pad(x, pad=self.ksize_2_padding)\n",
    "        return super(Conv2d, self).forward(x)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "eMJPLneEfaqT"
   },
   "source": [
    "class FastResNet(nn.Module):\n",
    "        \n",
    "    def __init__(self, num_classes=10, \n",
    "                 fmap_factor=64, conv_ksize=3, conv_pad=1, \n",
    "                 gbn_num_splits=512 // 32,                  \n",
    "                 classif_scale=0.0625):\n",
    "        super(FastResNet, self).__init__()\n",
    "                \n",
    "        self.prep = nn.Sequential(\n",
    "            Conv2d(3, fmap_factor, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),\n",
    "            GhostBatchNorm(fmap_factor, num_splits=gbn_num_splits, weight=False),\n",
    "            nn.CELU(alpha=0.3)\n",
    "        )\n",
    "\n",
    "        self.layer1 = nn.Sequential(\n",
    "            Conv2d(fmap_factor, fmap_factor * 2, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "            GhostBatchNorm(fmap_factor * 2, num_splits=gbn_num_splits, weight=False),\n",
    "            nn.CELU(alpha=0.3),\n",
    "            IdentityResidualBlock(fmap_factor * 2,\n",
    "                                  conv_ksize=conv_ksize, conv_pad=conv_pad, \n",
    "                                  gbn_num_splits=gbn_num_splits)\n",
    "        )\n",
    "        \n",
    "        self.layer2 = nn.Sequential(\n",
    "            Conv2d(fmap_factor * 2, fmap_factor * 4, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "            GhostBatchNorm(fmap_factor * 4, num_splits=gbn_num_splits, weight=False),\n",
    "            nn.CELU(alpha=0.3),            \n",
    "        )\n",
    "        \n",
    "        self.layer3 = nn.Sequential(\n",
    "            Conv2d(fmap_factor * 4, fmap_factor * 8, kernel_size=conv_ksize, padding=conv_pad, stride=1, bias=False),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "            GhostBatchNorm(fmap_factor * 8, num_splits=gbn_num_splits, weight=False),\n",
    "            nn.CELU(alpha=0.3),\n",
    "            IdentityResidualBlock(fmap_factor * 8, \n",
    "                                  conv_ksize=conv_ksize, conv_pad=conv_pad, \n",
    "                                  gbn_num_splits=gbn_num_splits)\n",
    "        )\n",
    "        \n",
    "        self.pool = nn.MaxPool2d(kernel_size=4)\n",
    "        \n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(fmap_factor * 8, num_classes)\n",
    "        )\n",
    "        # use the constructor argument so that `classif_scale` can actually be tuned\n",
    "        self.scale = torch.tensor(classif_scale, requires_grad=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.prep(x)\n",
    "        x = self.layer1(x)\n",
    "        x = self.layer2(x)\n",
    "        x = self.layer3(x)\n",
    "        x = self.pool(x)\n",
    "        y = self.classifier(x)\n",
    "        return y * self.scale\n",
    "      "
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "mSR5-EYdfaqT"
   },
   "source": [
    "model = FastResNet(10, fmap_factor=64)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "JB5VpI-KfaqU"
   },
   "source": [
    "def print_num_params(model, display_all_modules=False):\n",
    "    total_num_params = 0\n",
    "    for n, p in model.named_parameters():\n",
    "        num_params = 1\n",
    "        for s in p.shape:\n",
    "            num_params *= s\n",
    "        if display_all_modules: print(f\"{n}: {num_params}\")\n",
    "        total_num_params += num_params\n",
    "    print(\"-\" * 50)\n",
    "    print(f\"Total number of parameters: {total_num_params:.2e}\")\n",
    "    \n",
    "\n",
    "print_num_params(model)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "tCJ9xVa1faqW"
   },
   "source": [
    "model = FastResNet(10, fmap_factor=64, conv_ksize=2)\n",
    "\n",
    "print_num_params(model)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tEe1st3Dfaqd"
   },
   "source": [
    "### Setup dataflow\n",
    "\n",
    "We will set up the dataflow using `torchvision` transforms and will not follow the suggestions of the [bag-of-tricks notebook](https://github.com/davidcpage/cifar10-fast/blob/master/bag_of_tricks.ipynb). The data augmentations applied to the training set are:\n",
    "\n",
    "- Random Crop\n",
    "- Flip left-right\n",
    "- Cutout"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "LtUkVh21faqe"
   },
   "source": [
    "import torch\n",
    "from torchvision.transforms import Compose, Pad, RandomHorizontalFlip, RandomErasing, RandomCrop, Normalize"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "BuGB2zjWfaqe"
   },
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torchvision.transforms import ToTensor\n",
    "from torchvision.datasets.cifar import CIFAR10\n",
    "\n",
    "\n",
    "train_transform = Compose([\n",
    "    Pad(4),\n",
    "    RandomCrop(32),\n",
    "    RandomHorizontalFlip(),\n",
    "    ToTensor(),    \n",
    "    Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n",
    "    RandomErasing(scale=(0.0625, 0.0625), ratio=(1.0, 1.0))\n",
    "])\n",
    "\n",
    "\n",
    "test_transform = Compose([\n",
    "    ToTensor(),    \n",
    "    Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n",
    "])\n",
    "\n",
    "\n",
    "train_ds = CIFAR10(\"/tmp/cifar10\", train=True, download=True, transform=train_transform)\n",
    "test_ds = CIFAR10(\"/tmp/cifar10\", train=False, download=True, transform=test_transform)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "6r71eyznfaqf"
   },
   "source": [
    "def get_train_test_loaders():\n",
    "    train_loader = DataLoader(train_ds, batch_size=512, num_workers=10, shuffle=True, drop_last=True, pin_memory=True)\n",
    "    test_loader = DataLoader(test_ds, batch_size=512, num_workers=10, shuffle=False, drop_last=False, pin_memory=True)\n",
    "    return train_loader, test_loader"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NrCUZlGjfaqh"
   },
   "source": [
    "### Setup criterion, optimizer and lr scheduling\n",
    "\n",
    "Following cifar10-fast, we will use the label smoothing trick, which improves the training speed and generalization of neural networks in classification problems."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "8bX-LXoFfaqh"
   },
   "source": [
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "\n",
    "class CriterionWithLabelSmoothing(nn.Module):\n",
    "    \n",
    "    def __init__(self, criterion, alpha=0.2):\n",
    "        super(CriterionWithLabelSmoothing, self).__init__()\n",
    "        self.criterion = criterion\n",
    "        if self.criterion.reduction != 'none':\n",
    "            raise ValueError(\"Input criterion should have reduction='none'\")\n",
    "        self.alpha = alpha\n",
    "    \n",
    "    def forward(self, logits, targets):\n",
    "        loss = self.criterion(logits, targets)\n",
    "        log_probs = torch.log_softmax(logits, dim=1)\n",
    "        klloss = -log_probs.mean(dim=1)        \n",
    "        out = (1.0 - self.alpha) * loss + self.alpha * klloss\n",
    "        return out.mean(dim=0)"
   ],
   "execution_count": null,
   "outputs": []
  },
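  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The criterion above mixes the usual cross-entropy with the mean negative log-probability over all classes. The pure-Python sketch below is only an illustration (the helper names `log_softmax`, `smoothed_loss` and `cross_entropy_with_soft_targets` are hypothetical and not part of the notebook or of PyTorch). It suggests why this mixture counts as label smoothing: it equals the cross-entropy against a target distribution that keeps `1 - alpha` on the true class and spreads `alpha` uniformly over all classes.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def log_softmax(logits):\n",
    "    # Numerically stable log-softmax for a list of floats\n",
    "    m = max(logits)\n",
    "    log_z = m + math.log(sum(math.exp(x - m) for x in logits))\n",
    "    return [x - log_z for x in logits]\n",
    "\n",
    "\n",
    "def smoothed_loss(logits, target, alpha):\n",
    "    # Same mixture as CriterionWithLabelSmoothing.forward, for a single sample\n",
    "    log_probs = log_softmax(logits)\n",
    "    ce = -log_probs[target]\n",
    "    uniform_ce = -sum(log_probs) / len(log_probs)\n",
    "    return (1.0 - alpha) * ce + alpha * uniform_ce\n",
    "\n",
    "\n",
    "def cross_entropy_with_soft_targets(logits, target, alpha):\n",
    "    # Cross-entropy against the smoothed target distribution\n",
    "    k = len(logits)\n",
    "    log_probs = log_softmax(logits)\n",
    "    q = [alpha / k + (1.0 - alpha) * (i == target) for i in range(k)]\n",
    "    return -sum(qi * lp for qi, lp in zip(q, log_probs))\n",
    "\n",
    "\n",
    "logits, target, alpha = [2.0, -1.0, 0.5], 0, 0.2\n",
    "a = smoothed_loss(logits, target, alpha)\n",
    "b = cross_entropy_with_soft_targets(logits, target, alpha)\n",
    "print(abs(a - b) < 1e-12)\n",
    "```\n",
    "\n",
    "With `alpha = 0` both functions reduce to the plain cross-entropy of the true class."
   ]
  },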
  {
   "cell_type": "code",
   "metadata": {
    "id": "oI02itLjfaqh"
   },
   "source": [
    "def get_criterion(alpha):\n",
    "    # pass the tuned smoothing factor instead of a hard-coded value\n",
    "    return CriterionWithLabelSmoothing(nn.CrossEntropyLoss(reduction='none'), alpha=alpha)\n",
    "\n",
    "\n",
    "def get_optimizer(model, momentum, weight_decay, nesterov):\n",
    "    biases = [p for n, p in model.named_parameters() if \"bias\" in n]\n",
    "    others = [p for n, p in model.named_parameters() if \"bias\" not in n]\n",
    "    return optim.SGD(\n",
    "        [{\"params\": others, \"lr\": 1.0, \"weight_decay\": weight_decay}, \n",
    "         {\"params\": biases, \"lr\": 1.0, \"weight_decay\": weight_decay / 64}], \n",
    "        momentum=momentum, nesterov=nesterov\n",
    "    )\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "I8rCwiNgfaqi"
   },
   "source": [
    "The current PyTorch SGD implementation differs from the SGD used in cifar10-fast. The latter follows the formulation of Sutskever et al. (written here with Nesterov momentum):\n",
    "```\n",
    "v = mu * prev_v - lr * (dw + weight_decay * w)\n",
    "new_w = w + mu * v - lr * (dw + weight_decay * w)\n",
    "```\n",
    "whereas PyTorch uses:\n",
    "```\n",
    "v = mu * prev_v + dw + weight_decay * w\n",
    "new_w = w - lr * (mu * v + dw + weight_decay * w)\n",
    "```"
   ]
  },
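  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see when the two formulations actually differ, here is a small pure-Python sketch on the toy objective `0.5 * w**2`. The function names `pytorch_nesterov` and `sutskever_nesterov` are hypothetical and only transcribe the two update rules as scalar loops; they are not library code. With a constant learning rate the two trajectories should coincide up to rounding, and they separate once the learning rate changes between steps, which is the case with the schedule used in this notebook.\n",
    "\n",
    "```python\n",
    "def pytorch_nesterov(w, lrs, mu, wd=0.0):\n",
    "    # v = mu * prev_v + g;  w = w - lr * (g + mu * v)\n",
    "    v = 0.0\n",
    "    for lr in lrs:\n",
    "        g = w + wd * w  # gradient of 0.5 * w**2 plus weight decay\n",
    "        v = mu * v + g\n",
    "        w = w - lr * (g + mu * v)\n",
    "    return w\n",
    "\n",
    "\n",
    "def sutskever_nesterov(w, lrs, mu, wd=0.0):\n",
    "    # v = mu * prev_v - lr * g;  w = w + mu * v - lr * g\n",
    "    v = 0.0\n",
    "    for lr in lrs:\n",
    "        g = w + wd * w\n",
    "        v = mu * v - lr * g\n",
    "        w = w + mu * v - lr * g\n",
    "    return w\n",
    "\n",
    "\n",
    "constant = [0.1] * 5   # same learning rate at every step\n",
    "varying = [0.1, 0.2]   # learning rate changes between steps\n",
    "diff_constant = abs(pytorch_nesterov(1.0, constant, 0.9) - sutskever_nesterov(1.0, constant, 0.9))\n",
    "diff_varying = abs(pytorch_nesterov(1.0, varying, 0.9) - sutskever_nesterov(1.0, varying, 0.9))\n",
    "print(diff_constant, diff_varying)\n",
    "```\n",
    "\n",
    "The reason is that the Sutskever-style velocity already contains the learning rate of every past step, while the PyTorch buffer is rescaled by the current learning rate only."
   ]
  },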
  {
   "cell_type": "code",
   "metadata": {
    "id": "YexZcC92faqi"
   },
   "source": [
    "from ignite.handlers import PiecewiseLinear, ParamGroupScheduler"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "BavMV6RHfaqj"
   },
   "source": [
    "def get_lr_scheduler(optimizer, lr_max_value, lr_max_value_epoch, num_epochs, epoch_length):\n",
    "    milestones_values = [\n",
    "        (0, 0.0), \n",
    "        (epoch_length * lr_max_value_epoch, lr_max_value), \n",
    "        (epoch_length * num_epochs - 1, 0.0)\n",
    "    ]\n",
    "    lr_scheduler1 = PiecewiseLinear(optimizer, \"lr\", milestones_values=milestones_values, param_group_index=0)\n",
    "\n",
    "    milestones_values = [\n",
    "        (0, 0.0), \n",
    "        (epoch_length * lr_max_value_epoch, lr_max_value * 64), \n",
    "        (epoch_length * num_epochs - 1, 0.0)\n",
    "    ]\n",
    "    lr_scheduler2 = PiecewiseLinear(optimizer, \"lr\", milestones_values=milestones_values, param_group_index=1)\n",
    "\n",
    "    lr_scheduler = ParamGroupScheduler(\n",
    "        [lr_scheduler1, lr_scheduler2],\n",
    "        [\"lr scheduler (non-biases)\", \"lr scheduler (biases)\"]\n",
    "    )\n",
    "    \n",
    "    return lr_scheduler"
   ],
   "execution_count": null,
   "outputs": []
  },
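  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The schedule built above is a linear warm-up to `lr_max_value` followed by a linear decay to zero, expressed in iterations. As a rough illustration, the pure-Python sketch below interpolates between the same three milestones. The helper `piecewise_linear` is hypothetical and is not Ignite's implementation; holding the boundary value outside the milestones and the value `epoch_length = 97` (50000 training images, batch size 512, `drop_last=True`) are assumptions of this sketch.\n",
    "\n",
    "```python\n",
    "def piecewise_linear(iteration, milestones_values):\n",
    "    # milestones_values: list of (iteration, value) pairs sorted by iteration\n",
    "    first_it, first_val = milestones_values[0]\n",
    "    if iteration <= first_it:\n",
    "        return first_val\n",
    "    for (it0, v0), (it1, v1) in zip(milestones_values, milestones_values[1:]):\n",
    "        if it0 <= iteration <= it1:\n",
    "            # linear interpolation between the two surrounding milestones\n",
    "            return v0 + (v1 - v0) * (iteration - it0) / (it1 - it0)\n",
    "    # past the last milestone: keep the final value\n",
    "    return milestones_values[-1][1]\n",
    "\n",
    "\n",
    "epoch_length, n_epochs, peak_lr, peak_epoch = 97, 20, 1.0, 4\n",
    "milestones = [\n",
    "    (0, 0.0),\n",
    "    (epoch_length * peak_epoch, peak_lr),\n",
    "    (epoch_length * n_epochs - 1, 0.0),\n",
    "]\n",
    "for it in (0, 194, 388, 1939):\n",
    "    print(it, piecewise_linear(it, milestones))\n",
    "```\n",
    "\n",
    "The second scheduler in `get_lr_scheduler` follows the same shape with the peak multiplied by 64 for the bias parameter group."
   ]
  },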
  {
   "cell_type": "code",
   "metadata": {
    "id": "bDGjGeHofaqu"
   },
   "source": [
    "%matplotlib inline\n",
    "\n",
    "num_epochs = 25\n",
    "lr_max_value = 0.4\n",
    "milestones_values = [(0, 0.0), (num_epochs // 5, lr_max_value), (num_epochs - 1, 0.0)]\n",
    "\n",
    "PiecewiseLinear.plot_values(num_epochs, param_name=\"lr\", milestones_values=milestones_values)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ubg08s_yfaqv"
   },
   "source": [
    "### Setup hyperparameter tuning\n",
    "\n",
    "Now we are ready to set up hyperparameter tuning. We optimize the following parameters to get higher accuracy on the test dataset while training is limited to a small number of epochs (`num_epochs`, set in the cells below):\n",
    "\n",
    "- learning rate peak value: `[0.1, 1.0]`\n",
    "- SGD momentum: `[0.7, 1.0]`\n",
    "- weight decay: `[1e-4, 1e-3]`\n",
    "- label smoothing `alpha`: `[0.1, 0.5]`\n",
    "- number of features (`fmap_factor`): integer in `[48, 80]`\n",
    "- convolution kernel size: `3` or `2`\n",
    "- ..."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Yt3W860Hfaqw"
   },
   "source": [
    "from ax.plot.contour import plot_contour\n",
    "from ax.plot.trace import optimization_trace_single_method\n",
    "from ax.service.managed_loop import optimize\n",
    "from ax.utils.notebook.plotting import render, init_notebook_plotting\n",
    "\n",
    "init_notebook_plotting()"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6AHf4Tqwfaq2"
   },
   "source": [
    "First, we need to create an evaluation function that receives the experiment parameters and returns the test accuracy.\n",
    "\n",
    "The search space of the input parameters is defined as a list of dictionaries with the following keys: \n",
    "- \"name\" - parameter name, \n",
    "- \"type\" - parameter type (\"range\", \"choice\" or \"fixed\"), \n",
    "- \"bounds\" for range parameters, \n",
    "- \"values\" for choice parameters, and \n",
    "- \"value\" for fixed parameters.\n",
    "\n",
    "The parameters object passed to the evaluation function for a single trial is a dictionary that maps each parameter name to its value.\n",
    "\n",
    "\n",
    "Links: \n",
    "- [Ax Parameters API](https://ax.dev/api/core.html#module-ax.core.parameter)\n",
    "- [Ax optimize function](https://ax.dev/api/service.html#ax.service.managed_loop.optimize)\n",
    "- [Ax parameters search space example](https://ax.dev/tutorials/gpei_hartmann_service.html#2.-Set-up-experiment)"
   ]
  },
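  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the two data structures concrete, the sketch below builds a small search space in the dictionary format described above and draws one configuration from it. Both `example_space` and `sample_parameters` are hypothetical names used for illustration only. Ax chooses trial parameters with its own generation strategy, not with this uniform random sampler; the point is only to show the shape of the dictionary that the evaluation function receives.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "# Hypothetical search space in the list-of-dictionaries format\n",
    "example_space = [\n",
    "    {'name': 'lr_max_value', 'type': 'range', 'bounds': [0.1, 1.0]},\n",
    "    {'name': 'conv_ksize', 'type': 'choice', 'values': [2, 3]},\n",
    "    {'name': 'num_classes', 'type': 'fixed', 'value': 10},\n",
    "]\n",
    "\n",
    "\n",
    "def sample_parameters(space, rng):\n",
    "    # Draw one configuration uniformly at random (illustration only)\n",
    "    params = {}\n",
    "    for p in space:\n",
    "        if p['type'] == 'range':\n",
    "            low, high = p['bounds']\n",
    "            if isinstance(low, int) and isinstance(high, int):\n",
    "                params[p['name']] = rng.randint(low, high)\n",
    "            else:\n",
    "                params[p['name']] = rng.uniform(low, high)\n",
    "        elif p['type'] == 'choice':\n",
    "            params[p['name']] = rng.choice(p['values'])\n",
    "        elif p['type'] == 'fixed':\n",
    "            params[p['name']] = p['value']\n",
    "        else:\n",
    "            raise ValueError(f'Unknown parameter type: {p}')\n",
    "    return params\n",
    "\n",
    "\n",
    "print(sample_parameters(example_space, random.Random(0)))\n",
    "```\n",
    "\n",
    "An evaluation function then reads individual values with `parameters.get(name)`, which is what `run_experiment` does below."
   ]
  },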
  {
   "cell_type": "code",
   "metadata": {
    "id": "d0OloBgffaq3"
   },
   "source": [
    "from ignite.engine import create_supervised_trainer, create_supervised_evaluator, Events, convert_tensor\n",
    "from ignite.metrics import Accuracy\n",
    "from ignite.handlers import TensorboardLogger, ProgressBar\n",
    "from ignite.handlers.tensorboard_logger import OutputHandler, OptimizerParamsHandler, GradsHistHandler, \\\n",
    "    global_step_from_engine"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "LS0hvziTfaq3"
   },
   "source": [
    "# Transfer batch to GPU and set floating-point 16\n",
    "def prepare_batch_fp16(batch, device=None, non_blocking=True):\n",
    "    x, y = batch\n",
    "    return (convert_tensor(x, device=device, non_blocking=non_blocking).half(),\n",
    "            convert_tensor(y, device=device, non_blocking=non_blocking))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "ii-gTogTfaq6"
   },
   "source": [
    "torch.backends.cudnn.benchmark = True"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Vecp6Lu1faq6"
   },
   "source": [
    "num_epochs = 17\n",
    "\n",
    "\n",
    "def run_experiment(parameters):\n",
    "    device = 'cuda'\n",
    "    fast_mode = parameters.get(\"fast_mode\", True)\n",
    "    \n",
    "    # setup model\n",
    "    model = FastResNet(\n",
    "        num_classes=10, \n",
    "        fmap_factor=parameters.get(\"fmap_factor\"), \n",
    "        conv_ksize=parameters.get(\"conv_ksize\"),\n",
    "        classif_scale=parameters.get(\"classif_scale\")\n",
    "    ).to(device).half()\n",
    "    \n",
    "    # setup dataloaders \n",
    "    train_loader, test_loader = get_train_test_loaders()\n",
    "    \n",
    "    # setup solver\n",
    "    criterion = get_criterion(parameters.get(\"alpha\")).to(device)\n",
    "    optimizer = get_optimizer(\n",
    "        model, \n",
    "        parameters.get(\"momentum\"), \n",
    "        parameters.get(\"weight_decay\"),\n",
    "        parameters.get(\"nesterov\")\n",
    "    )\n",
    "    lr_scheduler = get_lr_scheduler(\n",
    "        optimizer, \n",
    "        parameters.get(\"lr_max_value\"),\n",
    "        parameters.get(\"lr_max_value_epoch\"),        \n",
    "        num_epochs=num_epochs,\n",
    "        epoch_length=len(train_loader)\n",
    "    )\n",
    "    \n",
    "    # setup ignite trainer\n",
    "    trainer = create_supervised_trainer(model, optimizer, criterion, \n",
    "                                        device=device, non_blocking=True,\n",
    "                                        prepare_batch=prepare_batch_fp16)\n",
    "    \n",
    "    # setup learning rate scheduler\n",
    "    trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)\n",
    "    \n",
    "    # setup tensorboard logger\n",
    "    exp_log_name = f\"exp_{parameters.get('fmap_factor')}_{parameters.get('conv_ksize')}_\" + \\\n",
    "        f\"{parameters.get('alpha'):.2}_{parameters.get('lr_max_value'):.4}\"\n",
    "    tb_logger = TensorboardLogger(log_dir=f\"/tmp/tb_logs/{exp_log_name}\")\n",
    "    \n",
    "    if not fast_mode:\n",
    "        # - log learning rate\n",
    "        tb_logger.attach(trainer, OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)\n",
    "\n",
    "        # - log training batch loss\n",
    "        tb_logger.attach(trainer, OutputHandler(tag=\"training\", output_transform=lambda x: {\"batch loss\": x}), \n",
    "                         event_name=Events.ITERATION_COMPLETED)\n",
    "\n",
    "        # - log model grads\n",
    "        tb_logger.attach(trainer, GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED)    \n",
    "    \n",
    "        # setup a progress bar\n",
    "        ProgressBar().attach(trainer, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)    \n",
    "        \n",
    "    # setup evaluator\n",
    "    def output_transform(output):\n",
    "        y_pred, y = output\n",
    "        y_pred = y_pred.float()\n",
    "        return y_pred, y\n",
    "\n",
    "    metrics = {\n",
    "        \"test accuracy\": Accuracy(output_transform=output_transform)\n",
    "    }\n",
    "    evaluator = create_supervised_evaluator(model, metrics=metrics, \n",
    "                                            device=device, non_blocking=True, \n",
    "                                            prepare_batch=prepare_batch_fp16)\n",
    "    \n",
    "    # evaluate the model every 3 epochs (only when not in fast mode) and after the last epoch\n",
    "    @trainer.on(Events.EPOCH_COMPLETED)\n",
    "    def run_evaluation(engine):\n",
    "        c1 = (engine.state.epoch - 1) % 3 == 0\n",
    "        c2 = engine.state.epoch == engine.state.max_epochs\n",
    "        if (c1 and not fast_mode) or c2:\n",
    "            evaluator.run(test_loader)\n",
    "    \n",
    "    if not fast_mode:\n",
    "        # - log test accuracy\n",
    "        tb_logger.attach(evaluator, \n",
    "                         OutputHandler(tag=\"validation\", metric_names=\"all\", \n",
    "                                                  global_step_transform=global_step_from_engine(trainer)), \n",
    "                         event_name=Events.EPOCH_COMPLETED)\n",
    "\n",
    "    trainer.run(train_loader, max_epochs=num_epochs)    \n",
    "    test_acc = evaluator.state.metrics['test accuracy']\n",
    "    \n",
    "    # dump hparams/result to Tensorboard\n",
    "    tb_logger.writer.add_hparams(parameters, {'hparam/test_accuracy': test_acc})\n",
    "\n",
    "    tb_logger.close()    \n",
    "    return test_acc"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wHDqJ59Yfaq7"
   },
   "source": [
    "The original training configuration gives us the following result:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "rDitZVaxfaq7"
   },
   "source": [
    "batch_size = 512\n",
    "num_epochs = 20\n",
    "\n",
    "run_experiment(\n",
    "    parameters={\n",
    "        \"fmap_factor\": 64,\n",
    "        \"conv_ksize\": 3,\n",
    "        \"classif_scale\": 0.0625,\n",
    "        \"alpha\": 0.2,\n",
    "        \"momentum\": 0.9,\n",
    "        \"weight_decay\": 5e-4,\n",
    "        \"nesterov\": True,\n",
    "        \"lr_max_value\": 1.0,\n",
    "        \"lr_max_value_epoch\": num_epochs // 5,\n",
    "        \"fast_mode\": False\n",
    "    }\n",
    ")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mExJHmKefaq8"
   },
   "source": [
    "#### Setup parameters search space"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "UovND_RDfaq8"
   },
   "source": [
    "parameters_space = [\n",
    "    {\n",
    "        \"name\": \"fmap_factor\",\n",
    "        \"type\": \"range\",\n",
    "        \"bounds\": [48, 80],\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"conv_ksize\",\n",
    "        \"type\": \"choice\",\n",
    "        \"values\": [2, 3],\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"classif_scale\",\n",
    "        \"type\": \"range\",\n",
    "        \"bounds\": [0.00625, 0.250],\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"alpha\",\n",
    "        \"type\": \"range\",\n",
    "        \"bounds\": [0.1, 0.5],\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"momentum\",\n",
    "        \"type\": \"range\",\n",
    "        \"bounds\": [0.7, 1.0],\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"weight_decay\",\n",
    "        \"type\": \"range\",\n",
    "        \"bounds\": [1e-4, 1e-3],\n",
    "        \"value_type\": \"float\",\n",
    "    },    \n",
    "    {\n",
    "        \"name\": \"nesterov\",\n",
    "        \"type\": \"choice\",\n",
    "        \"values\": [True, False],\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"lr_max_value\",\n",
    "        \"type\": \"range\",\n",
    "        \"bounds\": [0.1, 1.0],\n",
    "    },\n",
    "    {\n",
    "        \"name\": \"lr_max_value_epoch\",\n",
    "        \"type\": \"range\",\n",
    "        \"bounds\": [1, 10],\n",
    "    },\n",
    "]\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z2VfeJJ6faq9"
   },
   "source": [
    "### Start tuning"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "ZCDlqoEIfaq9"
   },
   "source": [
    "num_epochs = exp_num_epochs = 20\n",
    "\n",
    "\n",
    "best_parameters, values, experiment, model = optimize(\n",
    "    parameters=parameters_space,\n",
    "    evaluation_function=run_experiment,\n",
    "    objective_name='test accuracy',\n",
    "    total_trials=30\n",
    ")\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SoK4IRuxfaq-"
   },
   "source": [
    "The best parameters found give the following outcome:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "AeK3X7eofaq-"
   },
   "source": [
    "means, covariances = values\n",
    "print(f\"\\nBest parameters: {best_parameters}\\n\")\n",
    "print(f\"Test accuracy: {means} ± {covariances}\")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5IE4nWH0faq-"
   },
   "source": [
    "Let's plot contours showing the test accuracy as a function of two hyperparameters, here `lr_max_value` and `momentum`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "I4OExhm6faq_"
   },
   "source": [
    "render(plot_contour(model=model, param_x='lr_max_value', param_y='momentum', metric_name='test accuracy'))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qC7r9vjlfarA"
   },
   "source": [
    "Let's retrain the model with the best parameters found and compare it with the previous baseline:"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "rVAR5tvJfarA"
   },
   "source": [
    "batch_size = 512\n",
    "num_epochs = 20\n",
    "\n",
    "best_parameters_copy = dict(best_parameters)\n",
    "best_parameters_copy['fast_mode'] = False\n",
    "\n",
    "run_experiment(\n",
    "    parameters=best_parameters_copy\n",
    ")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "P0ob7uehfarC"
   },
   "source": [
    "In TensorBoard we can observe the \"HPARAMS\" tab:\n",
    "\n",
    "![hparams](https://github.com/abdulelahsm/ignite/blob/update-tutorials/examples/notebooks/assets/ax_hparams.png?raw=1)"
   ]
  }
 ]
}
